from ActualCausal.Train.Inter.train_trace import train_binaries
from ActualCausal.Inference.General.null import infer_null_values
import numpy as np

def train_null_bin(args, params, model, buffer, form="full", log_batch=None, additional=None, name="", itr_num=0, intermediate_logger=None):
    """Train the interaction model's binaries against null-inference targets.

    Gets the target binaries to train for (the traces, or a proxy like
    gradient, proximity, etc.) via ``infer_null_values``, sets
    ``params.trace_weights`` from the inferred weights, then calls
    ``train_binaries`` to train the interaction model to predict those values.

    Args:
        args, params, model, buffer: forwarded to ``infer_null_values`` and
            ``train_binaries``. ``params.trace_weights`` is overwritten as a
            side effect.
        form: ``"full"`` trains on ``null_traces[name]`` only and calls
            ``train_binaries`` with form ``"probs"``; any other value passes
            the whole ``null_traces`` object and uses form ``"all_probs"``.
        log_batch: forwarded to ``train_binaries``; ``None`` means a fresh
            empty list.
        additional: forwarded to ``train_binaries``; ``None`` means a fresh
            empty list.
        name: key used to select traces/weights. If empty, ``infer_names``
            is passed as an empty list.
        itr_num: forwarded to ``train_binaries``.
        intermediate_logger: forwarded to ``train_binaries``.

    Returns:
        Whatever ``train_binaries`` returns.
    """
    # Create fresh lists per call. A mutable default ([]) would be shared
    # across every call and leak state if anything downstream mutates it.
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional

    null_traces, null_weights, _ = infer_null_values(
        args, params, model, buffer, keep_all=True,
        infer_names=[name] if name else [],
    )

    if form == "full":
        form = "probs"
        null_traces = null_traces[name]
    else:
        # NOTE(review): this path keeps every inferred trace (not just `name`)
        # while the weights below still come from `name` -- confirm intended.
        form = "all_probs"

    # Shared by both forms: average over the last axis, then normalize so the
    # weights sum to 1.
    # NOTE(review): a zero sum yields NaN weights -- confirm that cannot occur.
    mean_weights = np.mean(null_weights[name], axis=-1)
    params.trace_weights = mean_weights / np.sum(mean_weights)

    return train_binaries(args, params, model, buffer, null_traces, form=form, log_batch=log_batch, additional=additional, itr_num=itr_num, intermediate_logger=intermediate_logger)